{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ExternalSource operator\n",
    "\n",
    "In this example, we will see how to use the `ExternalSource` operator with the PyTorch DALI iterator. The operator allows us to\n",
    "use an external data source as an input to the pipeline.\n",
    "\n",
    "To achieve that, we have to define an iterator or generator class whose `next` function\n",
    "returns one or several `numpy` arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import division\n",
    "import types\n",
    "import collections\n",
    "import numpy as np\n",
    "from random import shuffle\n",
    "from nvidia.dali.pipeline import Pipeline\n",
    "import nvidia.dali.ops as ops            \n",
    "import nvidia.dali.types as types\n",
    "\n",
    "# number of samples per batch, shared by the external iterator and the pipeline\n",
    "batch_size = 3\n",
    "# number of passes over the data set made by the demo loop in the last cell\n",
    "epochs = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExternalInputIterator(object):\n",
    "    \"\"\"Batch iterator over one shard of the files listed in file_list.txt.\n",
    "\n",
    "    Every call to next() returns a tuple (batch, labels) of two lists with\n",
    "    batch_size numpy uint8 arrays each: the raw bytes of an image file and\n",
    "    the matching one-element label array.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of samples returned by every call to next().\n",
    "        device_id: index of the shard served by this iterator.\n",
    "        num_gpus: number of shards the file list is divided into.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batch_size, device_id, num_gpus):\n",
    "        self.images_dir = \"../../data/images/\"\n",
    "        self.batch_size = batch_size\n",
    "        with open(self.images_dir + \"file_list.txt\", 'r') as f:\n",
    "            # Drop blank lines. The previous `line is not ''` test compared\n",
    "            # object identity and let them through, which made the split\n",
    "            # in __next__ fail on an empty entry.\n",
    "            self.files = [line.rstrip() for line in f if line.strip()]\n",
    "        # whole data set size\n",
    "        self.data_set_len = len(self.files)\n",
    "        # based on the device_id and total number of GPUs - world size\n",
    "        # get proper shard\n",
    "        self.files = self.files[self.data_set_len * device_id // num_gpus:\n",
    "                                self.data_set_len * (device_id + 1) // num_gpus]\n",
    "        self.n = len(self.files)\n",
    "        # Start position, so next() also works before iter() has been called.\n",
    "        self.i = 0\n",
    "\n",
    "    def __iter__(self):\n",
    "        \"\"\"Rewind to the first sample, reshuffle the shard and return self.\"\"\"\n",
    "        self.i = 0\n",
    "        shuffle(self.files)\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        \"\"\"Return the next (batch, labels) pair.\n",
    "\n",
    "        Raises:\n",
    "            StopIteration: when the shard contains no files.\n",
    "        \"\"\"\n",
    "        batch = []\n",
    "        labels = []\n",
    "\n",
    "        # self.i wraps modulo self.n below, so this only triggers for an\n",
    "        # empty shard, where it also prevents a modulo by zero.\n",
    "        if self.i >= self.n:\n",
    "            raise StopIteration\n",
    "\n",
    "        for _ in range(self.batch_size):\n",
    "            jpeg_filename, label = self.files[self.i].split(' ')\n",
    "            # Close the file as soon as its content has been read.\n",
    "            with open(self.images_dir + jpeg_filename, 'rb') as f:\n",
    "                batch.append(np.frombuffer(f.read(), dtype = np.uint8))\n",
    "            # Convert the label explicitly instead of handing numpy a string.\n",
    "            labels.append(np.array([int(label)], dtype = np.uint8))\n",
    "            # Wrap around, so an incomplete final batch is filled with\n",
    "            # samples from the beginning of the shard.\n",
    "            self.i = (self.i + 1) % self.n\n",
    "        return (batch, labels)\n",
    "\n",
    "    @property\n",
    "    def size(self,):\n",
    "        \"\"\"Number of entries in the whole file list (before sharding).\"\"\"\n",
    "        return self.data_set_len\n",
    "\n",
    "    # Python 2 iterator protocol alias.\n",
    "    next = __next__\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the pipeline\n",
    "\n",
    "Now let us define the pipeline itself. Because a framework iterator will be used, the images output by the pipeline must be uniform in size, so a `Resize` operator is used. Also, `iter_setup` will raise the `StopIteration` exception when the `ExternalInputIterator` runs out of data. Note that the iterator needs to be recreated at that point, so that the next time `iter_setup` is called it has data ready to consume."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExternalSourcePipeline(Pipeline):\n",
    "    \"\"\"Pipeline fed by an external iterable of (images, labels) batches.\n",
    "\n",
    "    The encoded images are decoded to RGB by a mixed-device ImageDecoder,\n",
    "    resized to 240x240 and cast to UINT8 on the GPU. The labels are\n",
    "    returned as they were fed in.\n",
    "\n",
    "    Args:\n",
    "        batch_size: forwarded to the Pipeline constructor.\n",
    "        num_threads: forwarded to the Pipeline constructor.\n",
    "        device_id: forwarded to the Pipeline constructor.\n",
    "        external_data: iterable whose iterator yields (images, labels)\n",
    "            pairs; iter() is called on it again each time it is exhausted.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batch_size, num_threads, device_id, external_data):\n",
    "        super(ExternalSourcePipeline, self).__init__(batch_size,\n",
    "                                      num_threads,\n",
    "                                      device_id,\n",
    "                                      seed=12)\n",
    "        self.input = ops.ExternalSource()\n",
    "        self.input_label = ops.ExternalSource()\n",
    "        self.decode = ops.ImageDecoder(device = \"mixed\", output_type = types.RGB)\n",
    "        self.res = ops.Resize(device=\"gpu\", resize_x=240, resize_y=240)\n",
    "        self.cast = ops.Cast(device = \"gpu\",\n",
    "                             dtype = types.UINT8)\n",
    "        self.external_data = external_data\n",
    "        self.iterator = iter(self.external_data)\n",
    "\n",
    "    def define_graph(self):\n",
    "        \"\"\"Build the graph and return the (images, labels) outputs.\"\"\"\n",
    "        # Kept on self because iter_setup feeds data into these two nodes.\n",
    "        self.jpegs = self.input()\n",
    "        self.labels = self.input_label()\n",
    "        images = self.decode(self.jpegs)\n",
    "        images = self.res(images)\n",
    "        output = self.cast(images)\n",
    "        return (output, self.labels)\n",
    "\n",
    "    def iter_setup(self):\n",
    "        \"\"\"Feed the next external batch into the two ExternalSource nodes.\n",
    "\n",
    "        Raises:\n",
    "            StopIteration: when the external iterator is exhausted. A fresh\n",
    "                iterator is created first, so the following call has data.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            # Built-in next() works with any iterator on Python 2 and 3,\n",
    "            # not only with objects that expose a next() method.\n",
    "            (images, labels) = next(self.iterator)\n",
    "            self.feed_input(self.jpegs, images)\n",
    "            self.feed_input(self.labels, labels)\n",
    "        except StopIteration:\n",
    "            self.iterator = iter(self.external_data)\n",
    "            raise StopIteration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using the pipeline\n",
    "\n",
    "Finally, let us see how it works. Please also notice the usage of `last_batch_padded`, which tells the iterator that the difference between the data set size and the batch size alignment is padded with real data that can be skipped when the batch is provided to the framework (`fill_last_batch`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, iter 0, real batch size: 3\n",
      "epoch: 0, iter 1, real batch size: 3\n",
      "epoch: 0, iter 2, real batch size: 3\n",
      "epoch: 0, iter 3, real batch size: 3\n",
      "epoch: 0, iter 4, real batch size: 3\n",
      "epoch: 0, iter 5, real batch size: 3\n",
      "epoch: 0, iter 6, real batch size: 3\n",
      "epoch: 1, iter 0, real batch size: 3\n",
      "epoch: 1, iter 1, real batch size: 3\n",
      "epoch: 1, iter 2, real batch size: 3\n",
      "epoch: 1, iter 3, real batch size: 3\n",
      "epoch: 1, iter 4, real batch size: 3\n",
      "epoch: 1, iter 5, real batch size: 3\n",
      "epoch: 1, iter 6, real batch size: 3\n",
      "epoch: 2, iter 0, real batch size: 3\n",
      "epoch: 2, iter 1, real batch size: 3\n",
      "epoch: 2, iter 2, real batch size: 3\n",
      "epoch: 2, iter 3, real batch size: 3\n",
      "epoch: 2, iter 4, real batch size: 3\n",
      "epoch: 2, iter 5, real batch size: 3\n",
      "epoch: 2, iter 6, real batch size: 3\n"
     ]
    }
   ],
   "source": [
    "from nvidia.dali.plugin.pytorch import DALIClassificationIterator as PyTorchIterator\n",
    "\n",
    "# A single shard: device 0 out of 1 GPU.\n",
    "eii = ExternalInputIterator(batch_size, 0, 1)\n",
    "pipe = ExternalSourcePipeline(batch_size=batch_size,\n",
    "                              num_threads=2,\n",
    "                              device_id=0,\n",
    "                              external_data=eii)\n",
    "pii = PyTorchIterator(pipe,\n",
    "                      size=eii.size,\n",
    "                      last_batch_padded=True,\n",
    "                      fill_last_batch=False)\n",
    "\n",
    "report = \"epoch: {}, iter {}, real batch size: {}\"\n",
    "for epoch in range(epochs):\n",
    "    for step, batch in enumerate(pii):\n",
    "        print(report.format(epoch, step, len(batch[0][\"data\"])))\n",
    "    # Rewind the framework iterator before the next epoch.\n",
    "    pii.reset()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
